# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Train CenterNet and get network model files(.ckpt)
"""

import os
import ast
import numpy as np
import moxing as mox
import mindspore.communication.management as D
from mindspore.communication.management import get_rank
from mindspore import context, Tensor
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore.nn.optim import Adam
from mindspore import log as logger
from mindspore.common import set_seed
from mindspore.profiler import Profiler

from src.dataset import COCOHP
from src import CenterNetMultiPoseLossCell, CenterNetWithLossScaleCell
from src import CenterNetWithoutLossScaleCell
from src import CenterNetMultiPoseEval
from src.utils import LossCallBack, CenterNetPolynomialDecayLR, CenterNetMultiEpochsDecayLR
from src.model_utils.config import config, dataset_config, net_config, train_config, eval_config, export_config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id, get_rank_id, get_device_num

_current_dir = os.path.dirname(os.path.realpath(__file__))


def _set_parallel_all_reduce_split():
    """Configure the all-reduce fusion split indices for CenterNet.

    The split list is selected by ``net_config.last_level``; only the values
    5 and 6 have a known split.

    Raises:
        ValueError: If ``net_config.last_level`` is neither 5 nor 6.
    """
    level = net_config.last_level
    known_splits = (
        (5, [16, 56, 96, 136, 175]),
        (6, [18, 59, 100, 141, 182]),
    )
    for known_level, split_indices in known_splits:
        if level == known_level:
            context.set_auto_parallel_context(
                all_reduce_fusion_config=split_indices)
            return
    raise ValueError(
        "The total num of allreduced grads for last level = {} is unknown,"
        "please re-split after known the true value".format(level))


def _get_params_groups(network, optimizer):
    """Split the trainable parameters into weight-decay groups.

    Args:
        network: Object exposing ``trainable_params()``.
        optimizer: Optimizer settings exposing ``decay_filter`` (a predicate
            applied to each parameter) and ``weight_decay``.

    Returns:
        list[dict]: Three entries, in this order:
            * parameters for which ``decay_filter`` is falsy, paired with
              ``optimizer.weight_decay``;
            * parameters for which ``decay_filter`` is truthy, paired with a
              weight decay of ``0.0``;
            * an ``order_params`` entry holding every trainable parameter in
              its original order.
    """
    all_params = network.trainable_params()
    rejected_by_filter = [p for p in all_params if not optimizer.decay_filter(p)]
    accepted_by_filter = [p for p in all_params if optimizer.decay_filter(p)]
    return [
        {'params': rejected_by_filter, 'weight_decay': optimizer.weight_decay},
        {'params': accepted_by_filter, 'weight_decay': 0.0},
        {'order_params': all_params},
    ]


def _get_optimizer(network, dataset_size):
    """Build the optimizer selected by ``train_config`` (only Adam is supported).

    Args:
        network: Network whose trainable parameters are optimized.
        dataset_size (int): Steps per epoch; passed to the "MultiDecay"
            schedule as ``steps_per_epoch``.

    Returns:
        Adam: Optimizer over the grouped parameters, driven by the learning
        rate schedule named in ``train_config.lr_schedule``.

    Raises:
        ValueError: If ``train_config.optimizer`` is not ``'Adam'`` or
            ``train_config.lr_schedule`` is neither ``"PolyDecay"`` nor
            ``"MultiDecay"``.
        TypeError: If ``train_config.MultiDecay.multi_epochs`` is not a list
            or tuple.
    """
    if train_config.optimizer != 'Adam':
        raise ValueError(
            "Don't support optimizer {}, only support [Adam]".format(
                train_config.optimizer))

    group_params = _get_params_groups(network, train_config.Adam)

    if train_config.lr_schedule == "PolyDecay":
        schedule_cfg = train_config.PolyDecay
        lr_schedule = CenterNetPolynomialDecayLR(
            learning_rate=schedule_cfg.learning_rate,
            end_learning_rate=schedule_cfg.end_learning_rate,
            warmup_steps=schedule_cfg.warmup_steps,
            decay_steps=config.train_steps,
            power=schedule_cfg.power)
    elif train_config.lr_schedule == "MultiDecay":
        schedule_cfg = train_config.MultiDecay
        multi_epochs = schedule_cfg.multi_epochs
        if not isinstance(multi_epochs, (list, tuple)):
            raise TypeError("multi_epochs must be list or tuple.")
        if not multi_epochs:
            # No milestones configured: fall back to a single milestone at
            # the configured number of epochs.
            multi_epochs = [config.epoch_size]
        lr_schedule = CenterNetMultiEpochsDecayLR(
            learning_rate=schedule_cfg.learning_rate,
            warmup_steps=schedule_cfg.warmup_steps,
            multi_epochs=multi_epochs,
            steps_per_epoch=dataset_size,
            factor=schedule_cfg.factor)
    else:
        # Report the offending schedule name (not the optimizer name) and the
        # schedule names this function actually accepts.
        raise ValueError(
            "Don't support lr_schedule {}, only support [PolyDecay, MultiDecay]"
            .format(train_config.lr_schedule))

    # Both schedules share the same Adam construction; only eps comes from
    # the schedule-specific config section.
    return Adam(group_params,
                learning_rate=lr_schedule,
                eps=schedule_cfg.eps,
                loss_scale=1.0)


def modelarts_pre_process():
    """Adjust config paths before training (hook passed to ``moxing_wrapper``).

    Points ``config.mindrecord_dir`` at ``config.data_path`` and re-roots
    ``config.save_checkpoint_path`` under ``config.output_path``.
    """
    config.mindrecord_dir = config.data_path
    checkpoint_root = os.path.join(config.output_path, config.save_checkpoint_path)
    config.save_checkpoint_path = checkpoint_root


def get_epoch(ckpt_name, steps_per_epoch=2003):
    """Derive the epoch index encoded in a checkpoint file name.

    The name is expected to look like ``<prefix>-<steps>_<suffix>``: the
    integer between the first ``-`` and the next ``_`` is read as a step
    count and floor-divided by ``steps_per_epoch``. If no ``_`` follows the
    counter, the counter is taken up to the first ``.`` instead (so
    ``prefix-4006.ckpt`` also parses).

    Args:
        ckpt_name (str): Checkpoint file name, e.g.
            ``checkpoint_centernet-4006_1.ckpt``.
        steps_per_epoch (int): Number of steps that make up one epoch.
            Defaults to 2003, the value that used to be hard-coded here
            (presumably the step count of the default dataset/batch
            configuration -- TODO confirm).

    Returns:
        int: ``steps // steps_per_epoch``.

    Raises:
        ValueError: If ``steps_per_epoch`` is not positive, or if the name
            has no ``-`` or no integer counter after it.
    """
    if steps_per_epoch <= 0:
        raise ValueError(
            "steps_per_epoch must be positive, got {}".format(steps_per_epoch))

    _, dash, tail = ckpt_name.partition('-')
    if not dash:
        raise ValueError(
            "Cannot find '-' in checkpoint name: {}".format(ckpt_name))

    counter_text, underscore, _ = tail.partition('_')
    if not underscore:
        # No '_' after the counter: drop the file extension instead of
        # slicing with a -1 end index, which would cut off the last character.
        counter_text = counter_text.partition('.')[0]

    try:
        steps = int(counter_text.strip())
    except ValueError:
        raise ValueError(
            "Cannot parse step count from checkpoint name: {}".format(ckpt_name))
    return int(steps // steps_per_epoch)


def get_ckpt_epoch(ckpt_dir):
    """Find the newest checkpoint file in a directory.

    Every ``.ckpt`` file in ``ckpt_dir`` is mapped to the epoch returned by
    ``get_epoch``. The file with the highest epoch wins; files that share the
    same epoch are ordered by modification time so the most recently written
    one is chosen rather than whichever ``os.listdir`` happens to list first.

    Args:
        ckpt_dir (str): Directory to scan (not recursive).

    Returns:
        tuple: ``(file_name, epoch)`` of the selected checkpoint, where
        ``file_name`` is relative to ``ckpt_dir``.

    Raises:
        ValueError: If ``ckpt_dir`` contains no ``.ckpt`` file.
    """
    ckpt_epoch = {
        file_name: get_epoch(file_name)
        for file_name in os.listdir(ckpt_dir)
        if os.path.splitext(file_name)[1] == '.ckpt'
    }
    if not ckpt_epoch:
        # Fail with an explicit message instead of the opaque
        # "max() arg is an empty sequence".
        raise ValueError(
            "No .ckpt file found in directory: {}".format(ckpt_dir))

    def _recency(file_name):
        # Epoch first; modification time breaks ties within one epoch.
        return (ckpt_epoch[file_name],
                os.path.getmtime(os.path.join(ckpt_dir, file_name)))

    newest_ckpt = max(ckpt_epoch, key=_recency)
    return newest_ckpt, ckpt_epoch[newest_ckpt]


def run_export(ckpt_path):
    """Export the checkpoint selected by ``get_ckpt_epoch`` as a model file.

    Loads that checkpoint into a ``CenterNetMultiPoseEval`` network, exports
    the network under the name and format given by ``export_config`` using a
    random float32 input of shape ``[1, 3, input_res[0], input_res[1]]``, and
    copies the exported file into ``config.train_url`` with ``mox.file.copy``.

    Args:
        ckpt_path (str): Directory holding the ``.ckpt`` files.
    """
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=0)
    eval_net = CenterNetMultiPoseEval(net_config, eval_config.K)

    selected_ckpt, _ = get_ckpt_epoch(ckpt_path)
    print("ckpt: ", selected_ckpt)
    weights = load_checkpoint(os.path.join(ckpt_path, selected_ckpt))
    load_param_into_net(eval_net, weights)
    eval_net.set_train(False)

    # Random input tensor handed to export() alongside the network.
    dummy_shape = [1, 3, export_config.input_res[0], export_config.input_res[1]]
    dummy_array = np.random.uniform(-1.0, 1.0, size=dummy_shape)
    dummy_input = Tensor(dummy_array.astype(np.float32))
    export(eval_net,
           dummy_input,
           file_name=export_config.export_name,
           file_format=export_config.export_format)

    # The source path for the copy is "<export_name>.<format in lowercase>".
    exported_file = export_config.export_name + "." + export_config.export_format.lower()
    mox.file.copy(exported_file, os.path.join(config.train_url, exported_file))


@moxing_wrapper(pre_process=modelarts_pre_process)
def train():
    """Train CenterNet and then export a checkpoint from the save directory.

    Steps, in order:
        1. Configure the MindSpore context (graph mode, target device) and,
           on Ascend with ``config.distribute == "true"``, data-parallel
           training.
        2. Build the training dataset through ``COCOHP``.
        3. Build the loss network, optimizer and callbacks.
        4. Optionally load ``config.load_checkpoint_path`` into the network.
        5. Run ``Model.train`` and finally call ``run_export`` on the
           checkpoint directory.

    Side effects:
        Mutates ``config``: on non-Ascend targets ``distribute``,
        ``need_profiler`` and ``enable_data_sink`` are forced to ``"false"``;
        ``train_steps`` is filled in when it is not positive.
    """
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=config.device_target)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(save_graphs=False)

    # Single-device defaults; overridden below for distributed Ascend runs.
    ckpt_save_dir = config.save_checkpoint_path
    rank = 0
    device_num = 1
    num_workers = 8
    if config.device_target == "Ascend":

        context.set_context(device_id=get_device_id())
        if config.distribute == "true":
            D.init()
            device_num = get_device_num()
            rank = get_rank_id()
            # NOTE(review): built by plain string concatenation, so this
            # presumably relies on config.save_checkpoint_path ending with a
            # path separator -- confirm.
            ckpt_save_dir = config.save_checkpoint_path + 'ckpt_' + str(
                get_rank()) + '/'

            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(
                parallel_mode=ParallelMode.DATA_PARALLEL,
                gradients_mean=True,
                device_num=device_num)
            _set_parallel_all_reduce_split()
    else:
        # Non-Ascend targets: distributed training, profiling and data
        # sinking are switched off regardless of the incoming config.
        config.distribute = "false"
        config.need_profiler = "false"
        config.enable_data_sink = "false"

    # Start create dataset!
    # create_train_dataset receives config.mindrecord_dir and
    # config.mindrecord_prefix; per the original note the mindrecord files are
    # expected to be named like centernet.mindrecord0, 1, ... file_num
    # (not verifiable from this file).
    logger.info("Begin creating dataset for CenterNet")
    coco = COCOHP(dataset_config,
                  run_mode="train",
                  net_opt=net_config,
                  enable_visual_image=(config.visual_image == "true"),
                  save_path=config.save_result_dir)
    dataset = coco.create_train_dataset(config.mindrecord_dir,
                                        config.mindrecord_prefix,
                                        batch_size=train_config.batch_size,
                                        device_num=device_num,
                                        rank=rank,
                                        num_parallel_workers=num_workers,
                                        do_shuffle=config.do_shuffle == 'true')
    dataset_size = dataset.get_dataset_size()
    logger.info("Create dataset done!")

    net_with_loss = CenterNetMultiPoseLossCell(net_config)

    # new_repeat_count is the value passed as the epoch argument of
    # Model.train below: total steps divided by the sink size.
    new_repeat_count = config.epoch_size * dataset_size // config.data_sink_steps
    if config.train_steps > 0:
        # An explicit train_steps caps the number of sink iterations.
        new_repeat_count = min(new_repeat_count,
                               config.train_steps // config.data_sink_steps)
    else:
        # train_steps not given: derive it, since the optimizer's PolyDecay
        # schedule reads config.train_steps as its decay_steps.
        config.train_steps = config.epoch_size * dataset_size
        logger.info("train steps: {}".format(config.train_steps))

    optimizer = _get_optimizer(net_with_loss, dataset_size)

    enable_static_time = config.device_target == "CPU"
    callback = [
        TimeMonitor(config.data_sink_steps),
        LossCallBack(dataset_size, enable_static_time)
    ]
    # Checkpoints are only written when enabled and when the device id is a
    # multiple of min(8, device_num).
    if config.enable_save_ckpt == "true" and get_device_id() % min(
            8, device_num) == 0:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=config.save_checkpoint_steps,
            keep_checkpoint_max=config.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(
            prefix='checkpoint_centernet',
            directory=None if ckpt_save_dir == "" else ckpt_save_dir,
            config=config_ck)
        callback.append(ckpoint_cb)

    # Optional warm start from an existing checkpoint.
    if config.load_checkpoint_path:
        param_dict = load_checkpoint(config.load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)
    # The loss-scale training cell is used on Ascend only.
    if config.device_target == "Ascend":
        net_with_grads = CenterNetWithLossScaleCell(
            net_with_loss,
            optimizer=optimizer,
            sens=train_config.loss_scale_value)
    else:
        net_with_grads = CenterNetWithoutLossScaleCell(net_with_loss,
                                                       optimizer=optimizer)

    model = Model(net_with_grads)
    model.train(new_repeat_count,
                dataset,
                callbacks=callback,
                dataset_sink_mode=(config.enable_data_sink == "true"),
                sink_size=config.data_sink_steps)
    # NOTE(review): run_export is called unconditionally, including when the
    # checkpoint callback above was not attached; get_ckpt_epoch needs at
    # least one .ckpt file in ckpt_save_dir. run_export also sets
    # device_target="Ascend" itself, whatever config.device_target is.
    run_export(ckpt_save_dir)


if __name__ == '__main__':
    # The profiler is only created when requested, so `profiler` is bound
    # only in that case.
    if config.need_profiler == "true":
        profiler = Profiler(output_path=config.profiler_path)
    # Fixed seed for every run.
    set_seed(0)
    train()
    # The flag is re-read here on purpose: train() sets config.need_profiler
    # to "false" on non-Ascend targets, in which case analyse() is skipped
    # even though a Profiler was created above.
    if config.need_profiler == "true":
        profiler.analyse()
